Ltx2.3 a2v & retake video and audio #1346

Merged
Artiprocher merged 5 commits into modelscope:main from mi804:ltx2.3_a2v
Mar 12, 2026

Conversation


@mi804 mi804 commented Mar 11, 2026

No description provided.

@mi804 mi804 changed the title from "Ltx2.3 a2v" to "Ltx2.3 a2v & retake video and audio" on Mar 11, 2026
@mi804 mi804 requested a review from Copilot March 11, 2026 12:20
Copilot AI left a comment


Pull request overview

Adds LTX-2.3 audio-to-video (A2V) and video/audio “retake” (region-based regeneration) support, along with runnable examples and documentation links for the new workflows.

Changes:

  • Add torchaudio as a dependency and introduce read_audio_with_torchaudio (+ resampling helper) in LTX2 media I/O.
  • Extend LTX2AudioVideoPipeline to support retake_video* and retake_audio* inputs via new pipeline units and inpaint-mask handling for both video and audio latents.
  • Add new LTX-2.3 TwoStage example scripts (normal + low-VRAM) and document them in README/docs tables.
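The inpaint-mask handling mentioned above typically follows the standard latent-inpainting pattern: regions flagged for regeneration keep the freshly denoised latents, while all other regions are pinned to the original input latents. A minimal sketch with plain Python lists — the names `denoise_mask` and `input_latents` mirror the PR's parameters, but the blending logic here is an illustrative assumption, not the repository's exact code:

```python
def apply_inpaint_mask(denoised, input_latents, denoise_mask):
    """Blend per element: where mask == 1, keep the denoised value
    (region being regenerated); where mask == 0, keep the original
    input latent (region to preserve)."""
    return [
        d if m else x
        for d, x, m in zip(denoised, input_latents, denoise_mask)
    ]

# Regenerate only the middle of a 5-element "latent" sequence.
denoised      = [9.0, 9.0, 9.0, 9.0, 9.0]
input_latents = [1.0, 2.0, 3.0, 4.0, 5.0]
denoise_mask  = [0,   1,   1,   1,   0]   # 1 = retake this region

print(apply_inpaint_mask(denoised, input_latents, denoise_mask))
# → [1.0, 9.0, 9.0, 9.0, 5.0]
```

In the actual pipeline the same idea is applied per denoising step on video and audio latent tensors rather than flat lists.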

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
pyproject.toml Adds torchaudio dependency needed by new/existing audio functionality.
diffsynth/utils/data/media_io_ltx2.py Adds torchaudio-based audio loading + optional resampling utilities.
diffsynth/pipelines/ltx2_audio_video.py Implements audio/video retake embedding + mask-driven denoising for A2V/retake.
examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py New A2V TwoStage example.
examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py New TwoStage retake example (video + audio regions).
examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py Low-VRAM variant of A2V TwoStage example.
examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py Low-VRAM variant of retake example.
docs/en/Model_Details/LTX-2.md Adds table entries linking to A2V/Retake examples.
docs/zh/Model_Details/LTX-2.md Adds table entries linking to A2V/Retake examples.
README.md Adds table entries linking to A2V/Retake examples.
README_zh.md Adds table entries linking to A2V/Retake examples.


def __init__(self):
    super().__init__(
        input_params=("retake_audio", "seed", "rand_device", "retake_audio_regions"),
        output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio", "audio_latents"),
Copilot AI Mar 11, 2026


LTX2AudioVideoUnit_AudioRetakeEmbedder declares output_params including denoise_mask_audio and audio_latents, but process() never returns audio_latents (and only conditionally returns denoise_mask_audio when retake_audio is provided). Since output_params are used to build the pipeline dependency/update graph, this mismatch can lead to incorrect unit splitting/order analysis during training/data-processing. Align output_params with the actual returned keys (or return audio_latents if it is intended to be set here).

Suggested change
output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio", "audio_latents"),
output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio"),
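The underlying problem generalizes: if a unit's declared `output_params` drift from the keys its `process()` actually returns, any dependency graph built on those declarations is wrong. A hypothetical standalone check illustrating the mismatch — the simplified dict-based interface here is an assumption for illustration, not DiffSynth's actual `PipelineUnit` API:

```python
def check_output_params(declared, returned_keys):
    """Return declared outputs that process() never actually produced."""
    return sorted(set(declared) - set(returned_keys))

declared = ("input_latents_audio", "audio_noise", "audio_positions",
            "audio_latent_shape", "denoise_mask_audio", "audio_latents")
# Keys actually returned by process() in this PR (per the review comment):
returned = {"input_latents_audio": None, "audio_noise": None,
            "audio_positions": None, "audio_latent_shape": None,
            "denoise_mask_audio": None}

print(check_output_params(declared, returned.keys()))
# → ['audio_latents']
```

A check like this could run at unit-registration time to catch declaration drift early.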

Comment on lines +147 to +155
def read_audio_with_torchaudio(path: str, start_time: float = 0, duration: float | None = None, resample: bool = False, resample_rate: int = 48000) -> torch.Tensor:
    waveform, sample_rate = torchaudio.load(path, channels_first=True)
    if resample:
        waveform = resample_waveform(waveform, sample_rate, resample_rate)
        sample_rate = resample_rate
    start_frame = int(start_time * sample_rate)
    if start_frame > waveform.shape[-1]:
        raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")
    end_frame = -1 if duration is None else int(duration * sample_rate + start_frame)
Copilot AI Mar 11, 2026


read_audio_with_torchaudio is annotated as returning torch.Tensor, but it actually returns a (waveform, sample_rate) tuple. Also, when duration is None the code sets end_frame = -1, which slices off the final sample (and changes the full-length audio by default). Consider updating the return type annotation to a tuple and using end_frame = None (or omitting the end index) when duration is not provided; also consider treating start_frame == waveform.shape[-1] as out-of-range if empty audio is not intended.

Suggested change
def read_audio_with_torchaudio(path: str, start_time: float = 0, duration: float | None = None, resample: bool = False, resample_rate: int = 48000) -> torch.Tensor:
    waveform, sample_rate = torchaudio.load(path, channels_first=True)
    if resample:
        waveform = resample_waveform(waveform, sample_rate, resample_rate)
        sample_rate = resample_rate
    start_frame = int(start_time * sample_rate)
    if start_frame > waveform.shape[-1]:
        raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")
    end_frame = -1 if duration is None else int(duration * sample_rate + start_frame)
def read_audio_with_torchaudio(
path: str,
start_time: float = 0,
duration: float | None = None,
resample: bool = False,
resample_rate: int = 48000,
) -> tuple[torch.Tensor, int]:
waveform, sample_rate = torchaudio.load(path, channels_first=True)
if resample:
waveform = resample_waveform(waveform, sample_rate, resample_rate)
sample_rate = resample_rate
start_frame = int(start_time * sample_rate)
if start_frame >= waveform.shape[-1]:
raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")
end_frame = None if duration is None else int(duration * sample_rate + start_frame)
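The `-1` vs `None` distinction the review points out is easy to verify with plain Python slicing, which behaves the same on lists and on the last dimension of a tensor:

```python
samples = [10, 20, 30, 40, 50]  # stand-in for a waveform's sample axis

# End index -1 silently drops the final sample:
print(samples[0:-1])    # → [10, 20, 30, 40]

# End index None keeps the full signal:
print(samples[0:None])  # → [10, 20, 30, 40, 50]
```

So with `duration=None`, the original code would return one sample fewer than the full file by default.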

@gemini-code-assist
Contributor

Summary of Changes


This pull request significantly expands the capabilities of the LTX-2.3 audio-video generation pipeline by integrating new audio-to-video generation and advanced retake features for both video and audio. These additions provide users with greater control and flexibility in generating and refining multimedia content, enhancing the overall functionality and usability of the system. The changes are supported by updated documentation and practical examples.

Highlights

  • LTX-2.3 Audio-Video Generation Model Support: Added comprehensive support for the LTX-2.3 audio-video generation model, including text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting, with full inference and training functionalities.
  • Audio-to-Video (A2V) Functionality: Introduced new audio-to-video generation capabilities within the LTX-2.3 pipeline, allowing video creation from audio inputs.
  • Video and Audio Retake Features: Implemented advanced video and audio retake functionalities, enabling users to regenerate specific regions of video frames or audio based on defined time regions.
  • Updated Documentation and Examples: Included new documentation entries and example scripts to demonstrate the usage of the LTX-2.3 model's new A2V and retake features, both for standard and low-VRAM inference.
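Regenerating "specific regions of video frames or audio based on defined time regions" implies converting second-based regions into a per-frame (or per-latent-frame) mask. A minimal sketch under the assumption that regions are `(start, end)` pairs in seconds — the region format and the helper name `regions_to_mask` are illustrative, not taken from the PR:

```python
def regions_to_mask(regions, num_frames, fps):
    """Mark frames whose timestamps fall inside any (start_s, end_s) region."""
    mask = [0] * num_frames
    for start_s, end_s in regions:
        for i in range(num_frames):
            t = i / fps  # timestamp of frame i in seconds
            if start_s <= t < end_s:
                mask[i] = 1
    return mask

# Retake the 0.5s-1.0s span of a 2-second, 4 fps clip (8 frames).
print(regions_to_mask([(0.5, 1.0)], num_frames=8, fps=4))
# Frames at t = 0.5 and t = 0.75 are selected → [0, 0, 1, 1, 0, 0, 0, 0]
```

The same conversion applies to audio, with the sample rate (or audio latent rate) taking the place of `fps`.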


Changelog
  • README.md
    • Updated with a new entry announcing LTX-2.3 audio-video generation model support.
    • Added new LTX-2.3 A2V and Retake pipeline examples to the model table.
  • README_zh.md
    • Updated with a new entry announcing LTX-2.3 audio-video generation model support in Chinese.
    • Added new LTX-2.3 A2V and Retake pipeline examples to the model table in Chinese.
  • diffsynth/pipelines/ltx2_audio_video.py
    • Added LTX2AudioVideoUnit_VideoRetakeEmbedder and LTX2AudioVideoUnit_AudioRetakeEmbedder to the pipeline initialization.
    • Modified the denoise_stage method to support audio inpainting masks.
    • Extended the call method to accept retake_video, retake_video_regions, retake_audio, audio_sample_rate, and retake_audio_regions parameters.
    • Updated LTX2AudioVideoUnit_InputVideoEmbedder and LTX2AudioVideoUnit_InputAudioEmbedder to handle training mode and removed NotImplementedError for video-to-video and audio-to-video.
    • Introduced LTX2AudioVideoUnit_VideoRetakeEmbedder and LTX2AudioVideoUnit_AudioRetakeEmbedder classes for handling video and audio retake logic, including mask generation.
    • Modified LTX2AudioVideoUnit_InputImagesEmbedder to correctly apply input images to video latents, incorporating input_latents_video and denoise_mask_video.
    • Updated LTX2AudioVideoUnit_SwitchStage2 to initialize input_latents_video and denoise_mask_video for stage 2.
    • Removed initial_latents from LTX2AudioVideoUnit_LatentsUpsampler output parameters and its return value.
    • Added input_latents_audio and denoise_mask_audio parameters to model_fn_ltx2 and applied audio inpainting logic.
  • diffsynth/utils/data/media_io_ltx2.py
    • Imported torchaudio for enhanced audio processing.
    • Added resample_waveform and read_audio_with_torchaudio functions for robust audio handling.
    • Applied minor formatting fixes.
  • docs/en/Model_Details/LTX-2.md
    • Added new LTX-2.3 A2V and Retake pipeline examples to the model table.
  • docs/zh/Model_Details/LTX-2.md
    • Added new LTX-2.3 A2V and Retake pipeline examples to the model table in Chinese.
  • examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py
    • Added a new example script for LTX-2.3 audio-to-video generation using a two-stage pipeline.
  • examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py
    • Added a new example script for LTX-2.3 text-to-audio/video with retake functionality using a two-stage pipeline.
  • examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py
    • Added a new low-VRAM example script for LTX-2.3 audio-to-video generation using a two-stage pipeline.
  • examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py
    • Added a new low-VRAM example script for LTX-2.3 text-to-audio/video with retake functionality using a two-stage pipeline.
  • pyproject.toml
    • Added 'torchaudio' to the project dependencies.
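Per the diff excerpts in this PR, `resample_waveform` wraps `torchaudio.functional.resample`; conceptually, resampling re-evaluates the signal on a new sample grid. The naive linear-interpolation version below is purely an illustration of that idea — real resampling (including torchaudio's) uses windowed-sinc filtering to avoid aliasing:

```python
def resample_linear(samples, source_rate, target_rate):
    """Naive linear-interpolation resample (illustration only)."""
    n_out = int(len(samples) * target_rate / source_rate)
    out = []
    for j in range(n_out):
        pos = j * source_rate / target_rate  # fractional source position
        i = int(pos)
        frac = pos - i
        nxt = samples[min(i + 1, len(samples) - 1)]
        out.append(samples[i] * (1 - frac) + nxt * frac)
    return out

# Upsample a 4-sample ramp from 4 Hz to 8 Hz:
print(resample_linear([0.0, 1.0, 2.0, 3.0], 4, 8))
# → [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.0]
```

The pipeline resamples to a fixed 48 kHz (`resample_rate: int = 48000` in the new loader) so downstream audio latents see a consistent rate.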

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces audio-to-video (A2V) and video/audio retake functionalities for the LTX-2.3 model. The changes include adding new pipeline units for handling retake video and audio, updating the main pipeline to accept new parameters, and adding corresponding example scripts. The implementation looks mostly correct, but I've found a critical issue in the pipeline unit definition that could break execution, a bug in audio slicing logic, and a minor type hint mismatch. My review includes suggestions to fix these issues.

def __init__(self):
    super().__init__(
        input_params=("retake_audio", "seed", "rand_device", "retake_audio_regions"),
        output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio", "audio_latents"),


critical

The output_params for this PipelineUnit includes "audio_latents", which is inconsistent with LTX2AudioVideoUnit_VideoRetakeEmbedder and requires the process method to return it (which it currently doesn't). The LTX2AudioVideoUnit_InputAudioEmbedder unit, which runs later in the pipeline, is responsible for setting audio_latents. To maintain consistency and ensure correct pipeline flow, "audio_latents" should be removed from output_params here.

Suggested change
output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio", "audio_latents"),
output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio"),

resample: bool = False,
resample_rate: int = 48000,
) -> tuple[torch.Tensor, int]:
waveform, sample_rate = torchaudio.load(path, channels_first=True)


high

When duration is None, end_frame is set to -1. In Python slicing, using -1 as the end index excludes the last element of the tensor. To slice until the very end of the tensor, None should be used instead. This will prevent unintentionally dropping the last audio sample.

Suggested change
end_frame = -1 if duration is None else int(duration * sample_rate + start_frame)
end_frame = None if duration is None else int(duration * sample_rate + start_frame)

resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(dtype=waveform.dtype)



Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The return type hint for this function is -> torch.Tensor, but it actually returns a tuple (waveform, sample_rate). The type hint should be updated to -> tuple[torch.Tensor, int] to accurately reflect the function's output.

Suggested change
def read_audio_with_torchaudio(path: str, start_time: float = 0, duration: float | None = None, resample: bool = False, resample_rate: int = 48000) -> torch.Tensor:
def read_audio_with_torchaudio(path: str, start_time: float = 0, duration: float | None = None, resample: bool = False, resample_rate: int = 48000) -> tuple[torch.Tensor, int]:

@Artiprocher Artiprocher merged commit 4741542 into modelscope:main Mar 12, 2026